Face Generation with DCGAN¶
This notebook demonstrates face generation using a DCGAN (Deep Convolutional GAN) trained on CelebA.
What makes this fun:
- Train a GAN that generates realistic human faces
- Watch the model learn facial features progressively
- Fast training on GPU (~30 minutes for quality results)
- Generate unlimited unique faces
Why DCGAN? Stable architecture with convolutional layers, batch normalization, and proven effectiveness for image generation.
# Installation
# !pip install torch torchvision matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
# Train on the GPU when one is available; all tensors below move via `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Seed every RNG in play so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
Using device: cuda
Part 1: Data Preparation¶
We'll use the CelebA dataset - 200k celebrity face images. The faces will be cropped, resized to 64x64, and normalized.
# CelebA faces are centre-cropped, resized to 64x64, and rescaled to [-1, 1]
# so real images live in the same range as the generator's tanh output.
image_size = 64
batch_size = 128

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# The first run downloads ~1.4 GB of images; later runs reuse ./data.
train_dataset = torchvision.datasets.CelebA(
    root='./data', split='train', download=True, transform=transform)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=2, pin_memory=True)

print(f'Training samples: {len(train_dataset):,}')
print(f'Batches per epoch: {len(train_loader):,}')
print(f'Image shape: {train_dataset[0][0].shape}')
# Show one batch of real faces as a visual baseline for the generated ones.
real_batch = next(iter(train_loader))[0]
samples = real_batch[:64]
grid = make_grid(samples, nrow=8, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(12, 12))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.title('Real CelebA Face Images', fontsize=16)
plt.axis('off')
plt.tight_layout()
plt.show()
Downloading... From (original): https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM From (redirected): https://drive.usercontent.google.com/download?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM&confirm=t&uuid=4f1cf598-a56a-45f9-8f98-c2bb644e80eb To: /content/data/celeba/img_align_celeba.zip 100%|██████████| 1.44G/1.44G [00:12<00:00, 120MB/s] Downloading... From: https://drive.google.com/uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U To: /content/data/celeba/list_attr_celeba.txt 100%|██████████| 26.7M/26.7M [00:00<00:00, 41.6MB/s] Downloading... From: https://drive.google.com/uc?id=1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS To: /content/data/celeba/identity_CelebA.txt 100%|██████████| 3.42M/3.42M [00:00<00:00, 36.2MB/s] Downloading... From: https://drive.google.com/uc?id=0B7EVK8r0v71pbThiMVRxWXZ4dU0 To: /content/data/celeba/list_bbox_celeba.txt 100%|██████████| 6.08M/6.08M [00:00<00:00, 336MB/s] Downloading... From: https://drive.google.com/uc?id=0B7EVK8r0v71pd0FJY3Blby1HUTQ To: /content/data/celeba/list_landmarks_align_celeba.txt 100%|██████████| 12.2M/12.2M [00:00<00:00, 363MB/s] Downloading... From: https://drive.google.com/uc?id=0B7EVK8r0v71pY0NSMzRuSXJEVkk To: /content/data/celeba/list_eval_partition.txt 100%|██████████| 2.84M/2.84M [00:00<00:00, 290MB/s]
Training samples: 162,770 Batches per epoch: 1,272 Image shape: torch.Size([3, 64, 64])
Part 2: Build DCGAN¶
DCGAN uses deep convolutional layers with batch normalization for stable training. No fully connected layers!
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to a 64x64 RGB image.

    Upsampling path: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64, each
    stage a strided transposed conv + BatchNorm + ReLU, ending in tanh.
    """

    def __init__(self, latent_dim=100, ngf=64):
        super().__init__()
        self.latent_dim = latent_dim
        # (in_ch, out_ch, stride, padding) per upsampling stage; the first
        # stage turns the 1x1 latent "pixel" into a 4x4 feature map.
        stages = [
            (latent_dim, ngf * 8, 1, 0),  # 1x1  -> 4x4
            (ngf * 8, ngf * 4, 2, 1),     # 4x4  -> 8x8
            (ngf * 4, ngf * 2, 2, 1),     # 8x8  -> 16x16
            (ngf * 2, ngf, 2, 1),         # 16x16 -> 32x32
        ]
        layers = []
        for in_ch, out_ch, stride, pad in stages:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Final stage: 32x32 -> 3 x 64 x 64, squashed to [-1, 1] by tanh.
        layers += [nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, noise):
        """Map noise of shape (N, latent_dim, 1, 1) to images (N, 3, 64, 64)."""
        return self.main(noise)
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 64x64 RGB images as real (1) or fake (0).

    Downsampling path: 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4 -> 1x1, each
    stage a strided conv + BatchNorm + LeakyReLU, ending in a sigmoid.
    """

    def __init__(self, ndf=64):
        super().__init__()
        # First stage has no BatchNorm, following the DCGAN architecture.
        layers = [
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),   # 64x64 -> 32x32
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three BatchNorm'd stages, doubling channels while halving size.
        for in_ch, out_ch in ((ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the 4x4 map to one probability per image.
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, image):
        """Return real/fake probabilities with shape (N, 1) for (N, 3, 64, 64) input."""
        return self.main(image).view(-1, 1)
# Build the two networks and move them to the training device.
generator = Generator(latent_dim=100, ngf=64).to(device)
discriminator = Discriminator(ndf=64).to(device)
# Initialize weights (DCGAN paper recommendation)
def weights_init(m):
    """Apply DCGAN weight initialization to a single module.

    Conv/ConvTranspose weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    with zero bias. Intended to be used via `model.apply(weights_init)`.
    """
    classname = m.__class__.__name__
    # Name matching covers Conv2d and ConvTranspose2d alike.
    if classname.find('Conv') != -1:
        # nn.init functions already run under no_grad, so operate on the
        # Parameter directly instead of the deprecated `.data` attribute.
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
# Apply the DCGAN weight initialization to every submodule of both networks.
generator.apply(weights_init)
discriminator.apply(weights_init)
print(f'Generator parameters: {sum(p.numel() for p in generator.parameters()):,}')
print(f'Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}')
Generator parameters: 3,576,704 Discriminator parameters: 2,765,568
Part 3: Training Setup¶
# Binary cross-entropy for the real/fake game; Adam with beta1=0.5 is the
# DCGAN paper's recommended setting for stable GAN training.
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Fixed noise batch (64 samples) reused every epoch so visualizations are
# directly comparable across training.
fixed_noise = torch.randn(64, generator.latent_dim, 1, 1).to(device)
print('Training setup complete!')
print(f'Fixed noise shape: {fixed_noise.shape}')
Training setup complete! Fixed noise shape: torch.Size([64, 100, 1, 1])
# One full pass of adversarial training over the loader.
def train_epoch(generator, discriminator, loader, optimizer_g, optimizer_d, criterion, device):
    """Train both networks for one epoch; return mean (D loss, G loss)."""
    generator.train()
    discriminator.train()
    d_losses, g_losses = [], []

    for real_images, _ in tqdm(loader, desc='Training'):
        n = real_images.size(0)
        real_images = real_images.to(device)
        # Target labels: 1 for real images, 0 for generated ones.
        ones = torch.ones(n, 1).to(device)
        zeros = torch.zeros(n, 1).to(device)

        # --- Discriminator step: maximize log(D(x)) + log(1 - D(G(z))) ---
        optimizer_d.zero_grad()
        loss_real = criterion(discriminator(real_images), ones)
        z = torch.randn(n, generator.latent_dim, 1, 1).to(device)
        fakes = generator(z)
        # detach() stops gradients from this loss reaching the generator.
        loss_fake = criterion(discriminator(fakes.detach()), zeros)
        d_loss = loss_real + loss_fake
        d_loss.backward()
        optimizer_d.step()

        # --- Generator step: maximize log(D(G(z))) ---
        optimizer_g.zero_grad()
        z = torch.randn(n, generator.latent_dim, 1, 1).to(device)
        # Fresh fakes; the generator wants D to label them as real.
        g_loss = criterion(discriminator(generator(z)), ones)
        g_loss.backward()
        optimizer_g.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    return np.mean(d_losses), np.mean(g_losses)
Part 4: Train the GAN¶
Watch the generated faces improve progressively! Early epochs will show blurry faces, later epochs will show realistic features.
epochs = 20
sample_interval = 2  # visualize generator output every N epochs
history = {'d_loss': [], 'g_loss': []}

for epoch in range(epochs):
    print(f'\n=== Epoch {epoch+1}/{epochs} ===')
    d_loss, g_loss = train_epoch(generator, discriminator, train_loader,
                                 optimizer_g, optimizer_d, criterion, device)
    history['d_loss'].append(d_loss)
    history['g_loss'].append(g_loss)
    print(f'D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}')

    # Render the fixed-noise batch at intervals so progress is comparable
    # across epochs (train_epoch re-enables train mode on the next pass).
    if epoch == 0 or (epoch + 1) % sample_interval == 0:
        generator.eval()
        with torch.no_grad():
            fake_images = generator(fixed_noise)
        grid = make_grid(fake_images, nrow=8, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(10, 10))
        plt.imshow(grid.permute(1, 2, 0).cpu())
        plt.title(f'Generated Faces - Epoch {epoch+1}', fontsize=16)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

print('\nTraining complete!')
=== Epoch 1/20 ===
Training: 100%|██████████| 1272/1272 [01:34<00:00, 13.51it/s]
D Loss: 0.6442 | G Loss: 6.2513
=== Epoch 2/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.79it/s]
D Loss: 0.6499 | G Loss: 3.5977
=== Epoch 3/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.66it/s]
D Loss: 0.7304 | G Loss: 2.6784 === Epoch 4/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.75it/s]
D Loss: 0.7767 | G Loss: 2.4050
=== Epoch 5/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.70it/s]
D Loss: 0.7661 | G Loss: 2.3691 === Epoch 6/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.61it/s]
D Loss: 0.7316 | G Loss: 2.3640
=== Epoch 7/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.68it/s]
D Loss: 0.7087 | G Loss: 2.4317 === Epoch 8/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.57it/s]
D Loss: 0.6845 | G Loss: 2.4593
=== Epoch 9/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.61it/s]
D Loss: 0.6667 | G Loss: 2.5262 === Epoch 10/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.66it/s]
D Loss: 0.6172 | G Loss: 2.5825
=== Epoch 11/20 ===
Training: 100%|██████████| 1272/1272 [01:34<00:00, 13.44it/s]
D Loss: 0.5901 | G Loss: 2.7086 === Epoch 12/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.70it/s]
D Loss: 0.5598 | G Loss: 2.8499
=== Epoch 13/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.57it/s]
D Loss: 0.4907 | G Loss: 3.0304 === Epoch 14/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.64it/s]
D Loss: 0.4463 | G Loss: 3.2286
=== Epoch 15/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.55it/s]
D Loss: 0.4580 | G Loss: 3.2955 === Epoch 16/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.60it/s]
D Loss: 0.4306 | G Loss: 3.4094
=== Epoch 17/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.67it/s]
D Loss: 0.4157 | G Loss: 3.5188 === Epoch 18/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.69it/s]
D Loss: 0.4293 | G Loss: 3.5587
=== Epoch 19/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.71it/s]
D Loss: 0.3659 | G Loss: 3.7131 === Epoch 20/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.66it/s]
D Loss: 0.3593 | G Loss: 3.6743
Training complete!
# Plot discriminator vs. generator loss per epoch.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history['d_loss'], label='Discriminator Loss', linewidth=2)
ax.plot(history['g_loss'], label='Generator Loss', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training Loss Over Time', fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.show()
Part 5: Generate More Faces¶
Generate new random faces on demand!
def generate_faces(num_samples=16):
    """Sample random latent vectors and display the generated faces.

    Uses the module-level `generator` and `device` set up during training.

    Args:
        num_samples: Number of faces to generate (shown on a 4-wide grid).
    """
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, generator.latent_dim, 1, 1).to(device)
        generated = generator(noise)
    grid = make_grid(generated, nrow=4, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    # Plain string literal: the original f-string had no placeholders.
    plt.title('Generated Faces', fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
# Draw three independent batches to showcase variety.
print('Generating random faces...')
for batch_idx in range(3):
    print(f'\nBatch {batch_idx+1}:')
    generate_faces(num_samples=16)
Generating random faces... Batch 1:
Batch 2:
Batch 3:
Part 6: Generate High-Resolution Grid¶
Create a large grid showing the variety of generated faces.
# Render a large 8x8 gallery of generated faces.
generator.eval()
num_faces = 64
with torch.no_grad():
    z = torch.randn(num_faces, generator.latent_dim, 1, 1).to(device)
    generated_faces = generator(z)
grid = make_grid(generated_faces, nrow=8, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(15, 15))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.title('Generated Face Gallery (64 unique faces)', fontsize=18)
plt.axis('off')
plt.tight_layout()
plt.show()
Part 7: Latent Space Exploration¶
Interpolate between two random points in latent space to see faces morph smoothly from one to another.
def interpolate_latent(num_steps=10):
    """Linearly interpolate between two random latent vectors and plot the
    generated faces, illustrating that the latent space is smooth.

    Uses the module-level `generator` and `device` set up during training.

    Args:
        num_steps: Number of interpolation points, including both endpoints.
    """
    generator.eval()
    # Random endpoints of the interpolation path.
    z1 = torch.randn(1, generator.latent_dim, 1, 1).to(device)
    z2 = torch.randn(1, generator.latent_dim, 1, 1).to(device)
    frames = []
    with torch.no_grad():
        for alpha in torch.linspace(0, 1, num_steps):
            z = (1 - alpha) * z1 + alpha * z2
            frames.append(generator(z))
    strip = torch.cat(frames)
    grid = make_grid(strip, nrow=num_steps, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(15, 3))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    # Plain string literal: the original f-string had no placeholders.
    plt.title('Latent Space Interpolation (smooth transitions)', fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
# Run three independent interpolation paths.
print('Latent space interpolation - watch faces morph smoothly!')
for run in range(3):
    print(f'\nInterpolation {run+1}:')
    interpolate_latent(num_steps=10)
Latent space interpolation - watch faces morph smoothly! Interpolation 1:
Interpolation 2:
Interpolation 3: